jax-metal segfaults when running Gemma inference

I tried running inference with the 2B model from https://github.com/google-deepmind/gemma on my M2 MacBook Pro, but it segfaults during sampling: https://pastebin.com/KECyz60T

Note: out of the box, the code tries to load bfloat16 weights, which fails. To work around this, I patched line 30 of gemma/params.py to cast explicitly to float32:

  param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)
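For anyone who wants to check the cast in isolation, here is a minimal sketch of the same tree_map call applied to a toy pytree. The dictionary keys and shapes are made up for illustration and are not the real Gemma parameter names; only the tree_map line mirrors the patch above.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the loaded checkpoint: a nested dict of bfloat16 arrays.
# The keys and shapes here are hypothetical, not Gemma's actual layout.
params = {
    "layer_0": {"w": jnp.ones((2, 3), dtype=jnp.bfloat16)},
    "layer_1": {"w": jnp.zeros((3,), dtype=jnp.bfloat16)},
}

# Same expression as the patched line: rebuild every leaf as a float32 array.
param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)

# Inspect the resulting dtypes leaf by leaf.
for leaf in jax.tree_util.tree_leaves(param_state):
    print(leaf.shape, leaf.dtype)
```

The tree structure is preserved; only the leaf dtype changes.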

Could you please provide the steps and a script to reproduce it?

To reproduce, first download the model checkpoint from https://www.kaggle.com/models/google/gemma/flax/2b-it

Clone the repository and install the dependencies:

git clone https://github.com/google-deepmind/gemma.git
cd gemma
python3 -m venv .
./bin/pip install jax-metal absl-py sentencepiece orbax chex flax

Patch it to use float32 params:

sed -i.bu 's/param_state = jax.tree_util.tree_map(jnp.array, params)/param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)/' gemma/params.py
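If the sed invocation does not behave the same way on your system, the same substitution can be sketched in Python. This is an illustrative helper, not part of the gemma repository: the names patch_params_source and patch_file are hypothetical, and the sketch assumes the target line reads exactly as in the sed pattern above.

```python
from pathlib import Path

# The line as it appears in the sed pattern, and its float32 replacement.
OLD = "param_state = jax.tree_util.tree_map(jnp.array, params)"
NEW = ("param_state = jax.tree_util.tree_map("
       "lambda p: jnp.array(p, jnp.float32), params)")


def patch_params_source(source: str) -> str:
    """Return the source text with the float32 cast applied."""
    if OLD not in source:
        # Fail loudly rather than silently leaving the file unpatched.
        raise ValueError("expected tree_map line not found; file may have changed")
    return source.replace(OLD, NEW)


def patch_file(path: str = "gemma/params.py") -> None:
    """Patch the file in place, keeping a .bu backup like `sed -i.bu` does."""
    target = Path(path)
    original = target.read_text()
    patched = patch_params_source(original)  # raises before anything is written
    Path(path + ".bu").write_text(original)
    target.write_text(patched)
```

Calling patch_file() from the repository root would then stand in for the sed command. Running it a second time raises ValueError, because the original line is no longer present.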

Run sampling and observe the segfault (paths here must reference the checkpoint downloaded in the first step):

PYTHONPATH=$(pwd) ./bin/python3 examples/sampling.py --path_checkpoint ~/models/gemma_2b_it/2b-it --path_tokenizer ~/models/gemma_2b_it/tokenizer.model